from human_pref import HumanPrefExperimentConfig
from clustering import ClusteringExperimentConfig
from robustness import RobustnessExperimentConfig
from sensitivity import SensitivityExperimentConfig
from retrieval import RetrievalExperimentConfig


# Embedding models evaluated by every experiment suite, in run order.
DEFAULT_MODELS = (
    "BAAI/bge-m3",
    "nomic-ai/nomic-embed-text-v1.5",
    "jinaai/jina-embeddings-v2-base-en",
    "text-embedding-3-small",
    "embed-english-v3.0",
)

# Datasets for each suite, in run order.
SENSITIVITY_DATASETS = (
    "paul_graham",
    "amazon_polarity",
    "arguana",
    "scientific_papers_sensitivity",
    "reddit",
)

ROBUSTNESS_DATASETS = (
    "scientific_papers_robustness",
    "cnn_dailymail",
    "billsum",
)

CLUSTERING_DATASETS = (
    "biorxiv-clustering-p2p",
    "stackexchange-clustering-p2p",
    "reddit-clustering-p2p",
    "arxiv-clustering-p2p",
    "medrxiv-clustering-p2p",
    "biorxiv-clustering-s2s",
    "stackexchange-clustering",
    "reddit-clustering",
    "arxiv-clustering-s2s",
    "medrxiv-clustering-s2s",
    "twentynewsgroups-clustering",
)

RETRIEVAL_DATASETS = (
    "trec-covid",
    "hotpotqa",
    "msmarco-v2",
    "nq",
    "quora-retrieval",
    "dbpedia",
    "cqadupstack-webmasters",
    "climate-fever",
    "fiqa",
    "nfcorpus",
    "touche2020",
    "scifact",
    "scidocs",
)


def _run_queue(config_cls, queue):
    """Build ``config_cls(**kwargs)`` for each entry of *queue* and call ``.run()`` in order."""
    for kwargs in queue:
        config = config_cls(**kwargs)
        config.run()


def _model_dataset_queue(models, datasets, num_examples, max_length):
    """Return one kwargs dict per (model, dataset) pair, with models as the outer loop."""
    return [
        {
            "dataset_name": dataset,
            "num_examples": num_examples,
            "model_name": model,
            "max_length": max_length,
        }
        for model in models
        for dataset in datasets
    ]


def main(models=None, *, max_length: int = 8192, num_examples: int = 10000):
    """Run every experiment suite for each embedding model.

    Suites execute sequentially in this order: human preference, sensitivity,
    robustness, clustering, retrieval. Each experiment is created from its
    config class and executed via the config's ``run()`` method.

    Args:
        models: Iterable of model identifiers to evaluate. Defaults to
            ``DEFAULT_MODELS``.
        max_length: Value passed as ``max_length`` to every experiment config.
        num_examples: Value passed as ``num_examples`` to every experiment
            config.
    """
    # Materialize once: the model list is iterated by every suite below, so a
    # one-shot iterable (e.g. a generator) would otherwise be exhausted
    # after the first suite.
    models = DEFAULT_MODELS if models is None else tuple(models)

    # Human Preference: one run per (mode, model), modes as the outer loop.
    human_pref_queue = [
        {
            "mode": mode,
            "num_examples": num_examples,
            "model_name": model,
            "max_length": max_length,
        }
        for mode in ("scores", "comparisons")
        for model in models
    ]
    _run_queue(HumanPrefExperimentConfig, human_pref_queue)

    # Sensitivity: one run per (mode, model, dataset). The needle settings
    # depend on the mode.
    sensitivity_queue = []
    for mode in ("insert", "remove"):
        if mode == "insert":
            needle_sizes = [0.05, 0.1, 0.2, 0.5, 1]
            needle_keywords = ["lorem"]
        else:
            # NOTE(review): "remove" omits size 1 and uses no keyword;
            # presumably removing 100% of the text is not meaningful — confirm.
            needle_sizes = [0.05, 0.1, 0.2, 0.5]
            needle_keywords = [None]

        sensitivity_queue.extend(
            {
                "mode": mode,
                "dataset_name": dataset,
                "num_examples": num_examples,
                "needle_keywords": needle_keywords,
                "needle_sizes": needle_sizes,
                "needle_posns": [0, 0.5, 1],
                "model_name": model,
                "max_length": max_length,
            }
            for model in models
            for dataset in SENSITIVITY_DATASETS
        )
    _run_queue(SensitivityExperimentConfig, sensitivity_queue)

    # Robustness, Clustering, Retrieval: one run per (model, dataset),
    # all sharing the same kwargs layout.
    _run_queue(
        RobustnessExperimentConfig,
        _model_dataset_queue(models, ROBUSTNESS_DATASETS, num_examples, max_length),
    )
    _run_queue(
        ClusteringExperimentConfig,
        _model_dataset_queue(models, CLUSTERING_DATASETS, num_examples, max_length),
    )
    _run_queue(
        RetrievalExperimentConfig,
        _model_dataset_queue(models, RETRIEVAL_DATASETS, num_examples, max_length),
    )


if __name__ == "__main__":
    main()
